Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding support for LGBM_BoosterUpdateOneIterCustom #114

Merged
merged 3 commits into from
Nov 29, 2021

Conversation

yaxxie
Copy link
Contributor

@yaxxie yaxxie commented Nov 18, 2021

As title

src/wrapper.jl Outdated
end
numdata = LGBM_DatasetGetNumData(first(bst.datasets))

if !(numdata == length(grads) == length(hessian))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be equal to numdata * LGBM_BoosterNumModelPerIteration()
https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterNumModelPerIteration
Please refer to microsoft/LightGBM#4815 (comment).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yaxxie I will wait for this before merging it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I have some modifications to make WRT this.


finished = LightGBM.LGBM_BoosterUpdateOneIterCustom(booster, randn(numdata), rand(numdata))
pred1 = LightGBM.LGBM_BoosterGetPredict(booster, 0)
# check both types of float work

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this randn(numdata) type of Vector{<:AbstractFloat}?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the test would fail with a MethodError if it was not

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is typeof Vector{Float64} which is the same as the following test Float32.(randn(numdata))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well it's different -- the following test is Float32

The API takes only Float32 data, so if the user passes Float64, it needs to be converted. I test each way independently just to verify both entry points work as expected.

@yaxxie
Copy link
Contributor Author

yaxxie commented Nov 29, 2021

@FatemehTahavori I'm going to merge this now

@yaxxie yaxxie merged commit c40fc81 into master Nov 29, 2021
@yaxxie yaxxie deleted the add_custom_gradients_boosting branch November 29, 2021 11:39
"""
LGBM_BoosterUpdateOneIterCustom
Pass grads and 2nd derivatives corresponding to some custom loss function
grads and 2nd derivatives must be same cardinality as training data

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as training data * number of trees per iteration

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


if !((numdata*nummodels) == length(grads) == length(hessian))
throw(DimensionMismatch(
"Gradients sizes ($(length(grads)), $(length(hessian))) don't match training data size ($numdata) * ($nummodels)"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

number of trees per iteration * ($nummodels)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm? Not sure what you're expecting here, number of trees per iteration is the number of models, so the message would be like "doesnt match training data (22) * (3)" for example


finished = LightGBM.LGBM_BoosterUpdateOneIterCustom(booster, randn(numdata*num_class), rand(numdata*num_class))
pred1 = LightGBM.LGBM_BoosterGetPredict(booster, 0)
# check both types of float work

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can remove this comment

test/ffi/booster.jl Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants